import numpy as np, pandas as pd
import matplotlib.pyplot as plt
from scipy import interpolate, integrate, optimize
import utils
import plot_specs as specs
import menu_cost as mc
import time_dependent as td
from matplotlib.lines import Line2D

# %% data

# Empirical distribution of standardized price changes, one row per bin:
# 'value' is the bin location and 'fraction' presumably the share of
# observations falling in that bin (TODO confirm against the data source).
df = pd.read_csv('empirical_distribution.csv')
x = df['value'].to_numpy()
y = df['fraction'].to_numpy()

# Rescale bin shares into a density by dividing by the bin width
# (assumes equally spaced bins, since only the first spacing is used).
y /= (x[1] - x[0])

freq = 0.1141  # targeted frequency of price adjustment, matched by the fit below
beta = 0.99 ** (1 / 3)  # quarterly beta=0.99 from paper, translated to monthly

# %% functions


def plot_jacobian(J, ax, xaxis=None, cols=(0, 10, 20), cut=None, color=None, linestyle='-', alpha=None):
    """Plot selected columns of a Jacobian as lines on a matplotlib-style Axes.

    Parameters
    ----------
    J : 2-D array
        Matrix whose columns ``J[:, c]`` are drawn, one line per entry of `cols`.
    ax : Axes-like
        Any object with a matplotlib-compatible ``plot(x, y, **kwargs)`` method.
    xaxis : array, optional
        Horizontal coordinates for the rows of `J`; defaults to ``np.arange(len(J))``.
    cols : iterable of int
        Column indices to plot. The default is a tuple to avoid the shared
        mutable-default pitfall; lists (or any iterable) are still accepted.
    cut : int, optional
        Number of leading rows to plot. Defaults to ``max(cols) + 61``, i.e. 60
        rows past the largest plotted column index.
    color, linestyle, alpha :
        Forwarded unchanged to every ``ax.plot`` call.
    """
    # Materialize once so a generator is not exhausted by max() before plotting.
    cols = tuple(cols)
    if xaxis is None:
        xaxis = np.arange(len(J))
    if cut is None:
        # default=0 keeps an empty `cols` from raising; nothing gets plotted then.
        cut = max(cols, default=0) + 61
    for c in cols:
        ax.plot(xaxis[:cut], J[:cut, c], color=color, linestyle=linestyle, alpha=alpha)


def forward_mapping(Lambda, sigma_eps, xbar):
    """Build the forward-iteration operator on beginning-of-period densities.

    The returned ``T`` maps a density function ``g`` to a new function
    ``g_plus(xs)`` built from two parts: mass from agents who did not adjust
    (weight ``1 - Lambda``), and mass from adjusters, who are placed at a normal
    density around 0 via ``utils.normal_pdf(xs, sigma_eps)``.
    """

    def T(g):
        # Total adjustment probability. The integral is taken over [0, xbar] and
        # doubled, which relies on Lambda * g being symmetric about 0.
        adj_freq = 2 * utils.integrate(lambda x: Lambda(x) * g(x), 0, xbar)

        def survivors(x):
            # density weight of agents at x who do not reset
            return (1 - Lambda(x)) * g(x)

        def gplus(xs):
            # Non-resetters: at x' today, take the expectation of survivors(x)
            # for x ~ N(x', sigma_eps). Resetters: adj_freq times the density
            # N(0, sigma_eps).
            return (utils.expectations_normal(survivors, xs, sigma_eps)
                    + adj_freq * utils.normal_pdf(xs, sigma_eps))

        return gplus

    return T


def g_from_Lambda(Lambda, sigma_eps, xbar, Ncubic=200, tol=1E-8, maxit=200):
    """Compute the stationary beginning-of-period density implied by hazard Lambda.

    Starting from a uniform density on [-xbar, xbar], repeatedly applies the
    forward mapping, evaluates it on an ``Ncubic``-point grid, and refits a cubic
    spline normalized to integrate to 1 over [-xbar, xbar].

    Parameters
    ----------
    Lambda : callable
        Price-adjustment hazard as a function of the gap x.
    sigma_eps : float
        Shock scale passed to ``forward_mapping``.
    xbar : float
        Half-width of the grid [-xbar, xbar].
    Ncubic : int
        Number of spline nodes.
    tol : float
        Convergence threshold on the max abs change of the (unnormalized)
        iterate between consecutive iterations.
    maxit : int
        Maximum number of iterations.

    Returns
    -------
    callable
        g(x): the normalized spline density.

    Raises
    ------
    ValueError
        If the iteration does not converge within ``maxit`` iterations (this
        includes ``maxit < 2``, where no change can be measured).
    """
    T = forward_mapping(Lambda, sigma_eps, xbar)

    def _normalized(spline, total):
        # Bind the spline and its integral now, so the returned function does not
        # depend on loop variables that are rebound in later iterations.
        return lambda x: spline(x) / total

    # grid for xs, and initial guess that the density is a uniform
    xs = np.linspace(-xbar, xbar, Ncubic)
    g = lambda x: 1 / (2 * xbar)

    gxs = None
    err = np.inf
    for _ in range(maxit):
        # apply mapping and evaluate at xs
        gxs_new = T(g)(xs)

        # obtain new interpolated g, normalized to integrate to 1
        graw = interpolate.CubicSpline(xs, gxs_new)
        g = _normalized(graw, graw.integrate(-xbar, xbar))

        if gxs is not None:
            err = np.max(np.abs(gxs_new - gxs))
            if err < tol:
                return g
        gxs = gxs_new

    raise ValueError(f'No convergence error is {err}!')


def Lambda_from_params(params):
    """Return the logistic price-adjustment hazard for ``params = (p0, p2, s, sigma)``.

    The hazard is ``1 / (1 + exp(-h(x)))`` with index
    ``h(x) = p0 + p2*x**2 - s*utils.normal_pdf(x, sigma)``. Here
    ``utils.normal_pdf`` is presumably a zero-mean normal density with scale
    sigma (confirm in utils).
    """
    p0, p2, s, sigma = params

    def hazard(x):
        index = p0 + p2 * x**2 - s * utils.normal_pdf(x, sigma)
        return 1 / (1 + np.exp(-index))

    return hazard


xbar = 5  # half-width of the grid [-xbar, xbar] on which densities are computed


def errors_from_params(params, verbose=False):
    """Residuals between model-implied and empirical price-change densities.

    ``params`` is laid out as ``[sigma_eps, p0, p2, s, sigma]``. The first entry
    is the shock scale; the rest parameterize the hazard via ``Lambda_from_params``.

    Returns ``g(x)*Lambda(x) - freq*y`` on the data grid ``x``. If
    ``Lambda(xbar) < 0.8``, it returns a constant penalty vector of 1E6 of the
    same length instead. Presumably this keeps the hazard high at the grid edge
    (confirm intent).
    """
    if verbose:
        print(params)

    sigma_eps, *hazard_params = params
    Lambda = Lambda_from_params(hazard_params)

    # Guard: reject parameterizations with too low a hazard at the boundary.
    if Lambda(xbar) < 0.8:
        return np.full(len(x), 1E6)

    g = g_from_Lambda(Lambda, sigma_eps, xbar)
    residuals = g(x) * Lambda(x) - freq * y

    if verbose:
        print(np.linalg.norm(residuals))
    return residuals


def expectation_mapping(Lambda, sigma_eps):
    """Build the one-step conditional expectation operator.

    The returned ``T`` maps a function ``m(x')`` of tomorrow's end-of-period gap
    to ``Em(xs)``, the expectation ``E[m(x') | x]`` given today's end-of-period
    gap x. Tomorrow's non-resetters keep ``m(x')``; resetters contribute ``m(0)``.
    """

    def T(m):
        def continuation(xp):
            # value weighted by the probability of not resetting at x'
            return (1 - Lambda(xp)) * m(xp)

        def Em(xs):
            # non-resetters: expectation of (1-Lambda(x')) m(x') for x' ~ N(x, sigma_eps)
            stay = utils.expectations_normal(continuation, xs, sigma_eps)
            # resetters: total adjustment probability times m evaluated at 0
            adjust = m(0) * utils.expectations_normal(Lambda, xs, sigma_eps)
            return stay + adjust

        return Em

    return T


# %% compute hazards

# Starting point for the parameter vector [sigma_eps, p0, p2, s, sigma].
guess = [0.25, -1, 1, 1, 0.1]
optimized = optimize.least_squares(
    errors_from_params,
    guess,
    method='lm',
    kwargs=dict(verbose=False),
)

# Rebuild the fitted hazard and its stationary density from the solution.
params = optimized.x
sigma_eps = params[0]
Lambda = Lambda_from_params(params[1:])
g = g_from_Lambda(Lambda, sigma_eps, xbar)

# Model-implied adjustment frequency: integral of g * Lambda over [-xbar, xbar].
freq_model = integrate.quad(lambda x: g(x) * Lambda(x), -xbar, xbar, epsabs=1E-13)[0]


# %% figures 13a, 13b

# Figure 13a: empirical price-change density (scaled by freq) vs. model g*Lambda.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
# Marker-only format 'o': the colour is given once via `color=`. The previous
# 'ro' also set a colour in the format string, which conflicts with the keyword
# and triggers matplotlib's "color is redundantly defined" warning.
ax.plot(x, freq*y, 'o', ms=5, color='firebrick', label='Data')
ax.plot(x, g(x)*Lambda(x), color='mediumblue', label='Fitted')
ax.set_xlabel('Standardized price change')
ax.set_ylabel('Density')
ax.set_xlim([-4, 4])
ax.legend()
specs.save_figure(fig, 'figure_13_a', close=True)

# Figure 13b: fitted adjustment hazard as a function of the price gap.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(x, Lambda(x), color='mediumblue')
ax.set_ylabel('Monthly price adjustment hazard')
ax.set_xlim([-4, 4])
ax.set_xlabel('Standardized price gap')
specs.save_figure(fig, 'figure_13_b', close=True)


# %% compute jacobians

# One-step conditional expectation operator E[m(x') | x] for the fitted hazard.
Texp = expectation_mapping(Lambda, sigma_eps)

# Fine grid on [-xbar, xbar] for spline representations of the expectations.
xs = np.linspace(-xbar, xbar, 1000)

T = 50  # number of horizons; Es gets one spline per horizon
Es = []
Exs = xs # initialize with identity

# Es[t] interpolates Texp applied t times to the identity m(x) = x
# (Es[0] is the identity itself).
for t in range(T):
    Et = interpolate.CubicSpline(xs, Exs)
    Exs = Texp(Et)(xs)
    Es.append(Et)


# Intensive margin: slope at x = 0 of each horizon's expectation function.
Phi_intensive = np.array([Et.derivative()(0) for Et in Es])
TPsi = 1600  # size of the Psi matrices (passed as T to the utils builders)
# NOTE(review): utils.Psi_for_td presumably maps the horizon profile into a
# TPsi x TPsi Jacobian; shape inferred from np.eye(TPsi) - Psi below — confirm in utils.
Psi_intensive = utils.Psi_for_td(Phi_intensive, beta, TPsi)
weight_intensive = freq_model * Phi_intensive.sum()

# Quadrature nodes/weights on [0, xbar]; presumably 60-point Gauss-Legendre (confirm in utils).
wl, xl = utils.legendre_quick(60, 0, xbar)

# Et_x[t, j] = Es[t] evaluated at quadrature node xl[j].
Et_x = np.empty((len(Es), len(xl)))
for t, Et in enumerate(Es):
    Et_x[t, :] = Et(xl)

# Central finite-difference derivative of the hazard at each node (step 1E-4).
Lambdap_x = (Lambda(xl+1E-4) - Lambda(xl-1E-4))/2E-4
# Extensive-margin weight per node. The factor 2 presumably accounts for the
# mirror-image half x < 0, since nodes only cover [0, xbar].
weight_x = 2 * Lambdap_x * g(xl) * Et_x.sum(axis=0)

# Horizon profile per node, scaled by the node location (nodes lie in (0, xbar]).
Phit_x = Et_x / xl

# Build one F per node, combine with quadrature weights times weight_x, then
# convert to a TPsi x TPsi Jacobian (utils.F_for_td / J_from_F semantics: see utils).
F_x = [utils.F_for_td(Phit, beta) for Phit in Phit_x.T]
F_extensive = sum(w*weight*F for w, weight, F in zip(wl, weight_x, F_x))
Psi_extensive = utils.J_from_F(F_extensive, T=TPsi)
weight_extensive = wl @ weight_x

# Normalize so Psi_extensive is a weighted average; the weight is reapplied below.
Psi_extensive /= weight_extensive

# Combine intensive and extensive margins, weighted by their total weights.
Psi = (weight_intensive*Psi_intensive + weight_extensive*Psi_extensive)/(weight_intensive + weight_extensive)
# K = (I - Psi)^{-1} Psi, then first-differenced along rows. This is the matrix
# plotted as 'Inflation' below, while Psi is plotted as 'Price level'. The
# in-place difference of overlapping views relies on NumPy (>= 1.13) detecting
# the ufunc memory overlap.
K = np.linalg.solve(np.eye(TPsi) - Psi, Psi)
K[1:, :] -= K[:-1, :]

# Keep only the leading 300 x 300 block for approximation and plotting.
K = K[:300, :300]
Psi = Psi[:300, :300]

# Approximating Jacobian from menu_cost. theta is then passed to
# td.calvo_jacobian, so it is presumably the matched Calvo parameter (confirm
# in menu_cost).
J_approx_real, theta, _, _, _ = mc.approx_jacobian(K, beta, Price=False, Nominal=False, Absolute=True)
J_approx_nom = td.calvo_jacobian(theta, beta, len(J_approx_real))


# %% figures 13c, 13d, F1


# Legend proxies: solid blue for the menu-cost model, dashed red for Calvo.
custom_lines = [
    Line2D([0], [0], color=colour, linestyle=style)
    for colour, style in (('mediumblue', '-'), ('firebrick', '--'))
]
legend = ['Menu cost', 'Calvo']

# Monthly periods 0..120 expressed in quarters.
xaxis = np.arange(121) / 3

# Figure 13c: Psi vs. the nominal Calvo Jacobian, columns 0, 30 and 60.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(Psi, ax, xaxis=xaxis, cols=[0, 30, 60], color='mediumblue')
plot_jacobian(J_approx_nom, ax, xaxis=xaxis, cols=[0, 30, 60], color='firebrick', linestyle='--')
ax.set(xlim=[0, 40], xlabel='Quarters', ylabel='Price level')
specs.save_figure(fig, 'figure_13_c', close=True)

# Figure 13d: K vs. the real approximating Jacobian, both scaled by 9
# (NOTE(review): the reason for the factor 9 is not stated here — confirm).
fig, ax = plt.subplots(figsize=specs.figsize_standard)
plot_jacobian(9 * K, ax, xaxis=xaxis, cols=[0, 30, 60], color='mediumblue')
plot_jacobian(9 * J_approx_real, ax, xaxis=xaxis, cols=[0, 30, 60], color='firebrick', linestyle='--')
ax.set(xlim=[0, 40], xlabel='Quarters', ylabel='Inflation')
ax.legend(custom_lines, legend)
specs.save_figure(fig, 'figure_13_d', close=True)

# Figure F1: extensive-margin weights mirrored onto negative gaps, halved.
fig, ax = plt.subplots(figsize=specs.figsize_standard)
ax.plot(np.concatenate((-xl[::-1], xl)), np.concatenate((weight_x[::-1], weight_x)) / 2, color='green')
ax.set_xlabel('Price gap')
specs.save_figure(fig, 'figure_F1', close=True)
